import torch  # Import PyTorch, a popular machine learning library
import torch.nn as nn  # Import the neural network module
import torch.optim as optim  # Import optimization algorithms
from torch.distributions import Categorical  # Import Categorical for probabilistic action sampling
import numpy as np  # Import NumPy for numerical computations
import gym  # Import OpenAI Gym for environment simulation

# Define Actor-Critic Network
class ActorCritic(nn.Module):
    """Two-headed network: a shared trunk feeding a policy head and a value head.

    Args:
        state_dim: length of the observation vector fed to the first Linear layer.
        action_dim: number of discrete actions; the policy head emits one
            probability per action.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 128  # width of the shared trunk, read by both heads
        # Creation order fixes the order in which parameters are randomly
        # initialised; attribute names become the state_dict keys.
        self.shared_layer = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())
        self.actor = nn.Sequential(nn.Linear(hidden, action_dim), nn.Softmax(dim=-1))
        self.critic = nn.Linear(hidden, 1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for ``state``.

        ``action_probs`` is softmax-normalised over the last dimension;
        ``state_value`` has a trailing dimension of size 1.
        """
        features = self.shared_layer(state)
        return self.actor(features), self.critic(features)

# Memory to store experiences
class Memory:
    """Rollout buffer holding one batch of per-step experience in parallel lists.

    Callers append to the lists directly; in this script ``states``, ``actions``
    and ``logprobs`` are filled by the agent while ``rewards`` and
    ``is_terminals`` are filled by the training loop, one entry per step.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Drop all stored experience by rebinding every field to a new empty list."""
        self.states, self.actions, self.logprobs = [], [], []
        self.rewards, self.is_terminals = [], []

# PPO Agent
class PPO:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Holds two ``ActorCritic`` networks: ``policy`` is trained by ``update`` and
    ``policy_old`` is the behaviour policy that ``select_action`` samples from.
    At the end of every ``update`` the trained weights are copied into
    ``policy_old``.

    Args:
        state_dim: length of the observation vector.
        action_dim: number of discrete actions.
        lr: Adam learning rate.
        gamma: discount factor for the Monte Carlo returns.
        eps_clip: clipping range; probability ratios are clamped to
            ``[1 - eps_clip, 1 + eps_clip]``.
        K_epochs: number of gradient steps taken on each collected batch.
        device: torch device (or device string) for networks and tensors.
            ``None`` picks CUDA when available, otherwise CPU.
    """

    def __init__(self, state_dim, action_dim, lr=0.002, gamma=0.99, eps_clip=0.2,
                 K_epochs=4, device=None):
        if device is None:
            # Explicit device instead of relying on a module-level global.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.policy = ActorCritic(state_dim, action_dim).to(self.device)  # network being optimised
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # Behaviour policy for rollouts; starts as an exact copy of `policy`.
        self.policy_old = ActorCritic(state_dim, action_dim).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.MseLoss = nn.MSELoss()  # critic regression loss

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

    def select_action(self, state, memory):
        """Sample an action from ``policy_old`` and record the step in ``memory``.

        Appends the state tensor, the sampled action and its log-probability to
        ``memory.states``, ``memory.actions`` and ``memory.logprobs``.

        Returns:
            The sampled action as a Python scalar (``Tensor.item()``).
        """
        # Rollout collection needs no gradients; without no_grad every stored
        # log-prob would keep an autograd graph alive until memory is cleared.
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float32, device=self.device)
            action_probs, _ = self.policy_old(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(logprob)
        return action.item()

    @staticmethod
    def _discounted_returns(rewards, is_terminals, gamma):
        """Return the discounted return-to-go for every step, restarting at terminal steps."""
        returns = []
        running = 0.0
        # Walk backwards so each step's return builds on the next step's.
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                running = 0.0  # a terminal step's return is just its own reward
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()  # O(n) instead of repeated insert(0, ...)
        return returns

    def update(self, memory):
        """Run ``K_epochs`` clipped-PPO optimisation steps on the batch in ``memory``.

        Critic targets are discounted Monte Carlo returns (normalised when the
        batch has more than one step); advantages are ``returns - V(s)``.
        Afterwards the trained weights are copied into ``policy_old``.
        Does nothing when ``memory`` holds no states. ``memory`` is not cleared.
        """
        if not memory.states:
            return  # torch.stack([]) would raise on an empty batch

        old_states = torch.stack(memory.states).to(self.device).detach()
        old_actions = torch.stack(memory.actions).to(self.device).detach()
        old_logprobs = torch.stack(memory.logprobs).to(self.device).detach()

        returns = torch.tensor(
            self._discounted_returns(memory.rewards, memory.is_terminals, self.gamma),
            dtype=torch.float32,
            device=self.device,
        )
        # std() of a single element is NaN, which would poison every weight.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)

        for _ in range(self.K_epochs):
            action_probs, state_values = self.policy(old_states)
            # squeeze(-1), not squeeze(): keep the batch dim when it has size 1.
            state_values = state_values.squeeze(-1)
            dist = Categorical(action_probs)
            new_logprobs = dist.log_prob(old_actions)
            entropy = dist.entropy()

            ratios = torch.exp(new_logprobs - old_logprobs)  # pi_new / pi_old
            advantages = returns - state_values.detach()  # no actor gradient into the critic

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            loss_actor = -torch.min(surr1, surr2).mean()
            loss_critic = self.MseLoss(state_values, returns)
            # Entropy bonus (subtracted) encourages exploration.
            loss = loss_actor + 0.5 * loss_critic - 0.01 * entropy.mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

# Runtime setup and hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available
env = gym.make("CartPole-v1")  # CartPole environment
state_dim = env.observation_space.shape[0]  # length of the observation vector
action_dim = env.action_space.n  # number of discrete actions
lr = 0.002  # Learning rate
gamma = 0.99  # Discount factor
eps_clip = 0.2  # Clipping parameter
K_epochs = 4  # Number of optimisation steps per update
max_episodes = 1000  # Maximum number of episodes
max_timesteps = 300  # Step cap per episode applied by this loop

# PPO Training
ppo = PPO(state_dim, action_dim, lr, gamma, eps_clip, K_epochs)
memory = Memory()

try:
    for episode in range(1, max_episodes + 1):
        state, _ = env.reset()
        total_reward = 0

        for _ in range(max_timesteps):
            action = ppo.select_action(state, memory)
            # step() returns (obs, reward, terminated, truncated, info); the
            # episode is over on either flag, not only on `terminated`.
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            memory.rewards.append(reward)
            memory.is_terminals.append(done)
            total_reward += reward

            if done:
                break

        # One update per episode, then discard that episode's experience.
        ppo.update(memory)
        memory.clear()

        print(f"Episode {episode}, Total Reward: {total_reward}")
finally:
    env.close()  # release the environment even if training raises or is interrupted
